{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convnet feature extractor with Keras 2.0\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2017-12-06, Keunwoo Choi (keunwoo.choi@qmul.ac.uk)\n",
    "\n",
    "Because the previous release (built with Keras 1.2.x) produced inconsistent features when the batch size varied, I trained a similar but slightly different convnet. \n",
    "\n",
    "### What's similar\n",
    " * 5-layer convnet\n",
    " * Average-pooling based feature computation\n",
    " * Trained for music tagging\n",
    "\n",
    "#### What's the same but actually bad\n",
    " * The current version of kapre has a slight bug in computing the melspectrogram. \n",
    "\n",
    "### What's better\n",
    " * Stable (or consistent) feature prediction regardless of the batch size\n",
    " * Works for any signal longer than 5.12 s\n",
    " * Uses a data sample-based normalization as an audio level normalizer, just in case.\n",
    "\n",
    "### What might be worse\n",
    " * The weights are different, so the features are different. I guess they'd be pretty similar, though\n",
    " * Tagging performance is 0.825 AUC, which is not as good as the previous one (0.845 AUC)\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# USAGE\n",
    "\n",
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "from keras.models import Model\n",
    "from keras.layers import GlobalAveragePooling2D as GAP2D\n",
    "from keras.layers import concatenate as concat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "channels_last\n",
      "tensorflow\n",
      "2.0.6\n",
      "0.1.2.1\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import kapre\n",
    "\n",
    "print(keras.backend.image_data_format())\n",
    "print(keras.backend.backend())\n",
    "print(keras.__version__)\n",
    "print(kapre.__version__)  # 0c37638"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "int_axis=0 passed but is ignored, str_axis is used instead.\n"
     ]
    }
   ],
   "source": [
    "model = keras.models.load_model('model_best.hdf5', \n",
    "                                custom_objects={'Melspectrogram':kapre.time_frequency.Melspectrogram,\n",
    "                                                'Normalization2D':kapre.utils.Normalization2D})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "melgram (Melspectrogram)     (None, 96, None, 1)       287840    \n",
      "_________________________________________________________________\n",
      "normalization2d_1 (Normaliza (None, 96, None, 1)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 96, None, 32)      320       \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 96, None, 32)      128       \n",
      "_________________________________________________________________\n",
      "elu_1 (ELU)                  (None, 96, None, 32)      0         \n",
      "_________________________________________________________________\n",
      "MP_1 (MaxPooling2D)          (None, 48, None, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 48, None, 32)      9248      \n",
      "_________________________________________________________________\n",
      "batch_normalization_2 (Batch (None, 48, None, 32)      128       \n",
      "_________________________________________________________________\n",
      "elu_2 (ELU)                  (None, 48, None, 32)      0         \n",
      "_________________________________________________________________\n",
      "MP_2 (MaxPooling2D)          (None, 16, None, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 16, None, 32)      9248      \n",
      "_________________________________________________________________\n",
      "batch_normalization_3 (Batch (None, 16, None, 32)      128       \n",
      "_________________________________________________________________\n",
      "elu_3 (ELU)                  (None, 16, None, 32)      0         \n",
      "_________________________________________________________________\n",
      "MP_3 (MaxPooling2D)          (None, 8, None, 32)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 8, None, 32)       9248      \n",
      "_________________________________________________________________\n",
      "batch_normalization_4 (Batch (None, 8, None, 32)       128       \n",
      "_________________________________________________________________\n",
      "elu_4 (ELU)                  (None, 8, None, 32)       0         \n",
      "_________________________________________________________________\n",
      "MP_4 (MaxPooling2D)          (None, 4, None, 32)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 4, None, 32)       9248      \n",
      "_________________________________________________________________\n",
      "batch_normalization_5 (Batch (None, 4, None, 32)       128       \n",
      "_________________________________________________________________\n",
      "elu_5 (ELU)                  (None, 4, None, 32)       0         \n",
      "_________________________________________________________________\n",
      "MP_5 (MaxPooling2D)          (None, 1, None, 32)       0         \n",
      "_________________________________________________________________\n",
      "MP (GlobalMaxPooling2D)      (None, 32)                0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 50)                1650      \n",
      "=================================================================\n",
      "Total params: 327,442\n",
      "Trainable params: 39,282\n",
      "Non-trainable params: 288,160\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()\n",
    "# MP: [(2, 4), (3, 4), (2, 5), (2, 4), (4, 4)]\n",
    "# --> Melspectrogram: should have more than 320 time frames.\n",
    "# --> Input audio should be >= 81,920 [samples] with sampling rate=16000.\n",
    "#     I.e., longer than 5.12 [seconds]."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a feature extractor based on the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "feat_layer1 = GAP2D()(model.get_layer('elu_1').output)\n",
    "feat_layer2 = GAP2D()(model.get_layer('elu_2').output)\n",
    "feat_layer3 = GAP2D()(model.get_layer('elu_3').output)\n",
    "feat_layer4 = GAP2D()(model.get_layer('elu_4').output)\n",
    "feat_layer5 = GAP2D()(model.get_layer('elu_5').output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_all = concat([feat_layer1, feat_layer2, feat_layer3, feat_layer4, feat_layer5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_extractor = Model(inputs=model.input, outputs=feat_all)\n",
    "# You just build the feature extractor. \n",
    "# This is a keras model, you can use .predict() method to get the features as below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________\n",
      "Layer (type)                 Output Shape        Param #    Connected to                  \n",
      "==========================================================================================\n",
      "melgram_input (InputLayer)   (None, 1, None)     0                                        \n",
      "__________________________________________________________________________________________\n",
      "melgram (Melspectrogram)     (None, 96, None, 1) 287840     melgram_input[0][0]           \n",
      "__________________________________________________________________________________________\n",
      "normalization2d_1 (Normaliza (None, 96, None, 1) 0          melgram[0][0]                 \n",
      "__________________________________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 96, None, 32 320        normalization2d_1[0][0]       \n",
      "__________________________________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 96, None, 32 128        conv2d_1[0][0]                \n",
      "__________________________________________________________________________________________\n",
      "elu_1 (ELU)                  (None, 96, None, 32 0          batch_normalization_1[0][0]   \n",
      "__________________________________________________________________________________________\n",
      "MP_1 (MaxPooling2D)          (None, 48, None, 32 0          elu_1[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 48, None, 32 9248       MP_1[0][0]                    \n",
      "__________________________________________________________________________________________\n",
      "batch_normalization_2 (Batch (None, 48, None, 32 128        conv2d_2[0][0]                \n",
      "__________________________________________________________________________________________\n",
      "elu_2 (ELU)                  (None, 48, None, 32 0          batch_normalization_2[0][0]   \n",
      "__________________________________________________________________________________________\n",
      "MP_2 (MaxPooling2D)          (None, 16, None, 32 0          elu_2[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 16, None, 32 9248       MP_2[0][0]                    \n",
      "__________________________________________________________________________________________\n",
      "batch_normalization_3 (Batch (None, 16, None, 32 128        conv2d_3[0][0]                \n",
      "__________________________________________________________________________________________\n",
      "elu_3 (ELU)                  (None, 16, None, 32 0          batch_normalization_3[0][0]   \n",
      "__________________________________________________________________________________________\n",
      "MP_3 (MaxPooling2D)          (None, 8, None, 32) 0          elu_3[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 8, None, 32) 9248       MP_3[0][0]                    \n",
      "__________________________________________________________________________________________\n",
      "batch_normalization_4 (Batch (None, 8, None, 32) 128        conv2d_4[0][0]                \n",
      "__________________________________________________________________________________________\n",
      "elu_4 (ELU)                  (None, 8, None, 32) 0          batch_normalization_4[0][0]   \n",
      "__________________________________________________________________________________________\n",
      "MP_4 (MaxPooling2D)          (None, 4, None, 32) 0          elu_4[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 4, None, 32) 9248       MP_4[0][0]                    \n",
      "__________________________________________________________________________________________\n",
      "batch_normalization_5 (Batch (None, 4, None, 32) 128        conv2d_5[0][0]                \n",
      "__________________________________________________________________________________________\n",
      "elu_5 (ELU)                  (None, 4, None, 32) 0          batch_normalization_5[0][0]   \n",
      "__________________________________________________________________________________________\n",
      "global_average_pooling2d_2 ( (None, 32)          0          elu_1[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "global_average_pooling2d_3 ( (None, 32)          0          elu_2[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "global_average_pooling2d_4 ( (None, 32)          0          elu_3[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "global_average_pooling2d_5 ( (None, 32)          0          elu_4[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "global_average_pooling2d_6 ( (None, 32)          0          elu_5[0][0]                   \n",
      "__________________________________________________________________________________________\n",
      "concatenate_2 (Concatenate)  (None, 160)         0          global_average_pooling2d_2[0][\n",
      "                                                            global_average_pooling2d_3[0][\n",
      "                                                            global_average_pooling2d_4[0][\n",
      "                                                            global_average_pooling2d_5[0][\n",
      "                                                            global_average_pooling2d_6[0][\n",
      "==========================================================================================\n",
      "Total params: 325,792\n",
      "Trainable params: 37,632\n",
      "Non-trainable params: 288,160\n",
      "__________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# model.summary() is always right.\n",
    "feat_extractor.summary(line_length=90)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Let's check that the features are really stable across batch sizes\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = np.random.random((12, 1, 16000*6))  # a mono, 6-second input signals with batch_size=12\n",
    "feat = feat_extractor.predict(inp, batch_size=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "feat2 = feat_extractor.predict(inp[0:1]) # what if the batch size is 1?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True], dtype=bool)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.isclose(feat[0], feat2[0], rtol=1e-5, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
